{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('test.conll.txt') as fopen:\n",
    "    corpus = fopen.read().split('\\n')\n",
    "    \n",
    "with open('dev.conll.txt') as fopen:\n",
    "    corpus_test = fopen.read().split('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2idx = {'PAD': 0,'NUM':1,'UNK':2}\n",
    "tag2idx = {'PAD': 0}\n",
    "char2idx = {'PAD': 0,'NUM':1,'UNK':2}\n",
    "word_idx = 3\n",
    "tag_idx = 1\n",
    "char_idx = 3\n",
    "\n",
    "def process_corpus(corpus, until = None):\n",
    "    global word2idx, tag2idx, char2idx, word_idx, tag_idx, char_idx\n",
    "    sentences, words, depends, labels = [], [], [], []\n",
    "    temp_sentence, temp_word, temp_depend, temp_label = [], [], [], []\n",
    "    for sentence in corpus:\n",
    "        if len(sentence):\n",
    "            sentence = sentence.split('\\t')\n",
    "            for c in sentence[1]:\n",
    "                if c not in char2idx:\n",
    "                    char2idx[c] = char_idx\n",
    "                    char_idx += 1\n",
    "            if sentence[7] not in tag2idx:\n",
    "                tag2idx[sentence[7]] = tag_idx\n",
    "                tag_idx += 1\n",
    "            if sentence[1] not in word2idx:\n",
    "                word2idx[sentence[1]] = word_idx\n",
    "                word_idx += 1\n",
    "            temp_word.append(word2idx[sentence[1]])\n",
    "            temp_depend.append(int(sentence[6]))\n",
    "            temp_label.append(tag2idx[sentence[7]])\n",
    "            temp_sentence.append(sentence[1])\n",
    "        else:\n",
    "            words.append(temp_word)\n",
    "            depends.append(temp_depend)\n",
    "            labels.append(temp_label)\n",
    "            sentences.append(temp_sentence)\n",
    "            temp_word = []\n",
    "            temp_depend = []\n",
    "            temp_label = []\n",
    "            temp_sentence = []\n",
    "    return sentences[:-1], words[:-1], depends[:-1], labels[:-1]\n",
    "        \n",
    "sentences, words, depends, labels = process_corpus(corpus)\n",
    "sentences_test, words_test, depends_test, labels_test = process_corpus(corpus_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.sequence import pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "words = pad_sequences(words,padding='post')\n",
    "depends = pad_sequences(depends,padding='post')\n",
    "labels = pad_sequences(labels,padding='post')\n",
    "\n",
    "words_test = pad_sequences(words_test,padding='post')\n",
    "depends_test = pad_sequences(depends_test,padding='post')\n",
    "labels_test = pad_sequences(labels_test,padding='post')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1700, 118)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "words_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_char_seq(batch, UNK = 2):\n",
    "    maxlen_c = max([len(k) for k in batch])\n",
    "    x = [[len(i) for i in k] for k in batch]\n",
    "    maxlen = max([j for i in x for j in i])\n",
    "    temp = np.zeros((len(batch),maxlen_c,maxlen),dtype=np.int32)\n",
    "    for i in range(len(batch)):\n",
    "        for k in range(len(batch[i])):\n",
    "            for no, c in enumerate(batch[i][k]):\n",
    "                temp[i,k,-1-no] = char2idx.get(c, UNK)\n",
    "    return temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx2word = {idx: tag for tag, idx in word2idx.items()}\n",
    "idx2tag = {i: w for w, i in tag2idx.items()}\n",
    "\n",
    "train_X = words\n",
    "train_Y = labels\n",
    "train_depends = depends\n",
    "train_char = generate_char_seq(sentences)\n",
    "\n",
    "test_X = words_test\n",
    "test_Y = labels_test\n",
    "test_depends = depends_test\n",
    "test_char = generate_char_seq(sentences_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim_word,\n",
    "        dim_char,\n",
    "        dropout,\n",
    "        learning_rate,\n",
    "        hidden_size_char,\n",
    "        hidden_size_word,\n",
    "        num_layers,\n",
    "        maxlen\n",
    "    ):\n",
    "        def cells(size, reuse = False):\n",
    "            return tf.contrib.rnn.DropoutWrapper(\n",
    "                tf.nn.rnn_cell.LSTMCell(\n",
    "                    size,\n",
    "                    initializer = tf.orthogonal_initializer(),\n",
    "                    reuse = reuse,\n",
    "                ),\n",
    "                output_keep_prob = dropout,\n",
    "            )\n",
    "\n",
    "        def bahdanau(embedded, size):\n",
    "            attention_mechanism = tf.contrib.seq2seq.LuongAttention(\n",
    "                num_units = hidden_size_word, memory = embedded\n",
    "            )\n",
    "            return tf.contrib.seq2seq.AttentionWrapper(\n",
    "                cell = cells(hidden_size_word),\n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = hidden_size_word,\n",
    "            )\n",
    "\n",
    "        self.word_ids = tf.placeholder(tf.int32, shape = [None, None])\n",
    "        self.char_ids = tf.placeholder(tf.int32, shape = [None, None, None])\n",
    "        self.labels = tf.placeholder(tf.int32, shape = [None, None])\n",
    "        self.depends = tf.placeholder(tf.int32, shape = [None, None])\n",
    "        self.maxlen = tf.shape(self.word_ids)[1]\n",
    "        self.lengths = tf.count_nonzero(self.word_ids, 1)\n",
    "\n",
    "        self.word_embeddings = tf.Variable(\n",
    "            tf.truncated_normal(\n",
    "                [len(word2idx), dim_word], stddev = 1.0 / np.sqrt(dim_word)\n",
    "            )\n",
    "        )\n",
    "        self.char_embeddings = tf.Variable(\n",
    "            tf.truncated_normal(\n",
    "                [len(char2idx), dim_char], stddev = 1.0 / np.sqrt(dim_char)\n",
    "            )\n",
    "        )\n",
    "\n",
    "        word_embedded = tf.nn.embedding_lookup(\n",
    "            self.word_embeddings, self.word_ids\n",
    "        )\n",
    "        char_embedded = tf.nn.embedding_lookup(\n",
    "            self.char_embeddings, self.char_ids\n",
    "        )\n",
    "        s = tf.shape(char_embedded)\n",
    "        char_embedded = tf.reshape(\n",
    "            char_embedded, shape = [s[0] * s[1], s[-2], dim_char]\n",
    "        )\n",
    "\n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (\n",
    "                state_fw,\n",
    "                state_bw,\n",
    "            ) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = cells(hidden_size_char),\n",
    "                cell_bw = cells(hidden_size_char),\n",
    "                inputs = char_embedded,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_char_%d' % (n),\n",
    "            )\n",
    "            char_embedded = tf.concat((out_fw, out_bw), 2)\n",
    "        output = tf.reshape(\n",
    "            char_embedded[:, -1], shape = [s[0], s[1], 2 * hidden_size_char]\n",
    "        )\n",
    "        word_embedded = tf.concat([word_embedded, output], axis = -1)\n",
    "\n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (\n",
    "                state_fw,\n",
    "                state_bw,\n",
    "            ) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = bahdanau(word_embedded, hidden_size_word),\n",
    "                cell_bw = bahdanau(word_embedded, hidden_size_word),\n",
    "                inputs = word_embedded,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_word_%d' % (n),\n",
    "            )\n",
    "            word_embedded = tf.concat((out_fw, out_bw), 2)\n",
    "\n",
    "        logits = tf.layers.dense(word_embedded, len(idx2tag))\n",
    "        logits_depends = tf.layers.dense(word_embedded, maxlen)\n",
    "        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(\n",
    "            logits, self.labels, self.lengths\n",
    "        )\n",
    "        with tf.variable_scope(\"depends\"):\n",
    "            log_likelihood_depends, transition_params_depends = tf.contrib.crf.crf_log_likelihood(\n",
    "                logits_depends, self.depends, self.lengths\n",
    "            )\n",
    "        self.cost = tf.reduce_mean(-log_likelihood) + tf.reduce_mean(-log_likelihood_depends)\n",
    "        self.optimizer = tf.train.AdamOptimizer(\n",
    "            learning_rate = learning_rate\n",
    "        ).minimize(self.cost)\n",
    "        \n",
    "        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)\n",
    "        \n",
    "        self.tags_seq, _ = tf.contrib.crf.crf_decode(\n",
    "            logits, transition_params, self.lengths\n",
    "        )\n",
    "        self.tags_seq_depends, _ = tf.contrib.crf.crf_decode(\n",
    "            logits_depends, transition_params_depends, self.lengths\n",
    "        )\n",
    "\n",
    "        self.prediction = tf.boolean_mask(self.tags_seq, mask)\n",
    "        mask_label = tf.boolean_mask(self.labels, mask)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "        \n",
    "        self.prediction = tf.boolean_mask(self.tags_seq_depends, mask)\n",
    "        mask_label = tf.boolean_mask(self.depends, mask)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy_depends = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_impl.py:100: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "dim_word = 128\n",
    "dim_char = 256\n",
    "dropout = 1\n",
    "learning_rate = 1e-3\n",
    "hidden_size_char = 64\n",
    "hidden_size_word = 64\n",
    "num_layers = 2\n",
    "batch_size = 32\n",
    "\n",
    "model = Model(dim_word,dim_char,dropout,learning_rate,hidden_size_char,hidden_size_word,num_layers,\n",
    "             words.shape[1])\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:40<00:00,  2.05it/s, accuracy=0.106, accuracy_depends=0.14, cost=105]   \n",
      "test minibatch loop: 100%|██████████| 54/54 [00:16<00:00,  3.87it/s, accuracy=0.136, accuracy_depends=0.0455, cost=164]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 56.09716725349426\n",
      "epoch: 0, training loss: 150.300666, training acc: 0.128670, training depends: 0.081010, valid loss: 141.028491, valid acc: 0.147059, valid depends: 0.120955\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.06it/s, accuracy=0.298, accuracy_depends=0.161, cost=90.5]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.83it/s, accuracy=0.309, accuracy_depends=0.118, cost=143] \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.58958339691162\n",
      "epoch: 1, training loss: 130.642711, training acc: 0.227523, training depends: 0.128279, valid loss: 124.700015, valid acc: 0.309620, valid depends: 0.129380\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.517, accuracy_depends=0.195, cost=74.7]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.83it/s, accuracy=0.518, accuracy_depends=0.145, cost=121] \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.429301023483276\n",
      "epoch: 2, training loss: 109.336289, training acc: 0.440639, training depends: 0.154167, valid loss: 105.495260, valid acc: 0.495497, valid depends: 0.153521\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.562, accuracy_depends=0.236, cost=67.6]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.83it/s, accuracy=0.573, accuracy_depends=0.2, cost=115]   \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.509848833084106\n",
      "epoch: 3, training loss: 96.279404, training acc: 0.549312, training depends: 0.185349, valid loss: 99.562834, valid acc: 0.568718, valid depends: 0.161316\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.682, accuracy_depends=0.257, cost=59.7]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.84it/s, accuracy=0.627, accuracy_depends=0.2, cost=105]   \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.52307176589966\n",
      "epoch: 4, training loss: 86.934430, training acc: 0.639515, training depends: 0.214143, valid loss: 90.451283, valid acc: 0.642919, valid depends: 0.181113\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.09it/s, accuracy=0.747, accuracy_depends=0.274, cost=53]  \n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.88it/s, accuracy=0.673, accuracy_depends=0.209, cost=99.6]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.530587673187256\n",
      "epoch: 5, training loss: 79.217252, training acc: 0.707435, training depends: 0.240276, valid loss: 85.398946, valid acc: 0.691384, valid depends: 0.198120\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.801, accuracy_depends=0.353, cost=48.9]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.91it/s, accuracy=0.709, accuracy_depends=0.182, cost=94.8]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.583584785461426\n",
      "epoch: 6, training loss: 72.303042, training acc: 0.762662, training depends: 0.274533, valid loss: 82.404467, valid acc: 0.727524, valid depends: 0.189771\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.09it/s, accuracy=0.805, accuracy_depends=0.377, cost=43.4]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.87it/s, accuracy=0.727, accuracy_depends=0.191, cost=90.9]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.58053755760193\n",
      "epoch: 7, training loss: 65.055744, training acc: 0.798943, training depends: 0.321437, valid loss: 80.717043, valid acc: 0.746362, valid depends: 0.196639\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.09it/s, accuracy=0.853, accuracy_depends=0.455, cost=38.9]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.86it/s, accuracy=0.745, accuracy_depends=0.209, cost=93.3]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.67471981048584\n",
      "epoch: 8, training loss: 58.739642, training acc: 0.827087, training depends: 0.377910, valid loss: 81.661547, valid acc: 0.749696, valid depends: 0.195816\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.866, accuracy_depends=0.527, cost=35.2]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.85it/s, accuracy=0.727, accuracy_depends=0.145, cost=101] \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.61992311477661\n",
      "epoch: 9, training loss: 54.076346, training acc: 0.848619, training depends: 0.417288, valid loss: 80.947128, valid acc: 0.767324, valid depends: 0.209349\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.09it/s, accuracy=0.904, accuracy_depends=0.507, cost=33]  \n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.90it/s, accuracy=0.782, accuracy_depends=0.209, cost=91.2]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.548739194869995\n",
      "epoch: 10, training loss: 50.326555, training acc: 0.863248, training depends: 0.458952, valid loss: 79.820367, valid acc: 0.774822, valid depends: 0.222942\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.06it/s, accuracy=0.911, accuracy_depends=0.558, cost=30.4]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.83it/s, accuracy=0.791, accuracy_depends=0.227, cost=89.9]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.604267597198486\n",
      "epoch: 11, training loss: 45.569131, training acc: 0.877704, training depends: 0.509152, valid loss: 80.193576, valid acc: 0.779312, valid depends: 0.218611\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.928, accuracy_depends=0.688, cost=23.6]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.85it/s, accuracy=0.791, accuracy_depends=0.145, cost=91.4]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.518343448638916\n",
      "epoch: 12, training loss: 41.137693, training acc: 0.893106, training depends: 0.548518, valid loss: 82.710994, valid acc: 0.784206, valid depends: 0.220646\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:40<00:00,  2.08it/s, accuracy=0.935, accuracy_depends=0.678, cost=24.9]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.84it/s, accuracy=0.809, accuracy_depends=0.164, cost=97.1]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.76531481742859\n",
      "epoch: 13, training loss: 37.963494, training acc: 0.906679, training depends: 0.583725, valid loss: 82.073511, valid acc: 0.782869, valid depends: 0.243221\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.10it/s, accuracy=0.942, accuracy_depends=0.733, cost=18.8]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.86it/s, accuracy=0.836, accuracy_depends=0.145, cost=101] \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.35486125946045\n",
      "epoch: 14, training loss: 34.554710, training acc: 0.917006, training depends: 0.620208, valid loss: 85.657195, valid acc: 0.784520, valid depends: 0.241463\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.10it/s, accuracy=0.966, accuracy_depends=0.781, cost=15.8]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.89it/s, accuracy=0.864, accuracy_depends=0.173, cost=103] \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.17195224761963\n",
      "epoch: 15, training loss: 31.398126, training acc: 0.924450, training depends: 0.656583, valid loss: 85.506386, valid acc: 0.793804, valid depends: 0.255798\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.06it/s, accuracy=0.935, accuracy_depends=0.774, cost=15.4]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.83it/s, accuracy=0.818, accuracy_depends=0.164, cost=98.7]\n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.265546560287476\n",
      "epoch: 16, training loss: 28.956760, training acc: 0.932248, training depends: 0.678523, valid loss: 84.795803, valid acc: 0.796438, valid depends: 0.264498\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.10it/s, accuracy=0.959, accuracy_depends=0.75, cost=16.2] \n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.90it/s, accuracy=0.8, accuracy_depends=0.164, cost=111]   \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.26459813117981\n",
      "epoch: 17, training loss: 27.902658, training acc: 0.938587, training depends: 0.685745, valid loss: 90.332116, valid acc: 0.796167, valid depends: 0.239845\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.08it/s, accuracy=0.959, accuracy_depends=0.856, cost=12.1]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.89it/s, accuracy=0.827, accuracy_depends=0.191, cost=102] \n",
      "train minibatch loop:   0%|          | 0/76 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.02435898780823\n",
      "epoch: 18, training loss: 24.752691, training acc: 0.943680, training depends: 0.727282, valid loss: 88.909203, valid acc: 0.802451, valid depends: 0.263541\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 76/76 [00:39<00:00,  2.12it/s, accuracy=0.976, accuracy_depends=0.877, cost=9.05]\n",
      "test minibatch loop: 100%|██████████| 54/54 [00:15<00:00,  3.92it/s, accuracy=0.836, accuracy_depends=0.155, cost=110] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.076401472091675\n",
      "epoch: 19, training loss: 21.722709, training acc: 0.951147, training depends: 0.767359, valid loss: 92.559914, valid acc: 0.800110, valid depends: 0.274829\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "for e in range(20):\n",
    "    lasttime = time.time()\n",
    "    train_acc, train_loss, test_acc, test_loss, train_acc_depends, test_acc_depends = 0, 0, 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        batch_x = train_X[i : min(i + batch_size, train_X.shape[0])]\n",
    "        batch_char = train_char[i : min(i + batch_size, train_X.shape[0])]\n",
    "        batch_y = train_Y[i : min(i + batch_size, train_X.shape[0])]\n",
    "        batch_depends = train_depends[i : min(i + batch_size, train_X.shape[0])]\n",
    "        acc_depends, acc, cost, _ = sess.run(\n",
    "            [model.accuracy_depends, model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.word_ids: batch_x,\n",
    "                model.char_ids: batch_char,\n",
    "                model.labels: batch_y,\n",
    "                model.depends: batch_depends\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss += cost\n",
    "        train_acc += acc\n",
    "        train_acc_depends += acc_depends\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc, accuracy_depends = acc_depends)\n",
    "        \n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'test minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]\n",
    "        batch_char = test_char[i : min(i + batch_size, test_X.shape[0])]\n",
    "        batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]\n",
    "        batch_depends = test_depends[i : min(i + batch_size, test_X.shape[0])]\n",
    "        acc_depends, acc, cost = sess.run(\n",
    "            [model.accuracy_depends, model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.word_ids: batch_x,\n",
    "                model.char_ids: batch_char,\n",
    "                model.labels: batch_y,\n",
    "                model.depends: batch_depends\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        test_loss += cost\n",
    "        test_acc += acc\n",
    "        test_acc_depends += acc_depends\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc, accuracy_depends = acc_depends)\n",
    "    \n",
    "    train_loss /= len(train_X) / batch_size\n",
    "    train_acc /= len(train_X) / batch_size\n",
    "    train_acc_depends /= len(train_X) / batch_size\n",
    "    test_loss /= len(test_X) / batch_size\n",
    "    test_acc /= len(test_X) / batch_size\n",
    "    test_acc_depends /= len(test_X) / batch_size\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, training depends: %f, valid loss: %f, valid acc: %f, valid depends: %f\\n'\n",
    "        % (e, train_loss, train_acc, train_acc_depends, test_loss, test_acc, test_acc_depends)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq, deps = sess.run([model.tags_seq, model.tags_seq_depends],\n",
    "        feed_dict={model.word_ids:batch_x[:1],\n",
    "                  model.char_ids:batch_char[:1]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq = seq[0]\n",
    "deps = deps[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([18, 19,  2,  6,  3,  4, 16, 18, 23, 20, 19,  2], dtype=int32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq[seq>0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([18, 19,  2,  6,  3,  7, 16, 18, 23, 20, 19,  2], dtype=int32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_y[0][seq>0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2,  4,  4,  4,  8,  8,  4, 10, 12, 12,  8,  4], dtype=int32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deps[seq>0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2,  6,  6,  5,  6,  0,  6, 11, 11, 11,  6,  6], dtype=int32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_depends[0][seq>0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
